#!/usr/bin/env python3
"""
Phase 8: Section 9.2 Hypothesis Tests — Why the Demographic Channel Weakens Post-Opening
========================================================================================
Converts Section 9.2's three discussion-only mechanisms into formal tests:

TEST 1 — Disruption of Structural Confounders
  (a) Partial R² decomposition: Z₁ predicted by observables, pre vs post
  (b) Mediation test: add structural controls to Z₁→CA, pre vs post
  (c) Oster (2019) δ bounds for pre vs post

TEST 2 — Countervailing Capital Flows
  (a) Gross flow decomposition: Z₁ on gross assets/liabilities, pre vs post
  (b) CA component split: savings-investment gap, pre vs post
  (c) Reserve accumulation proxy for capital flight

TEST 3 — Composition Effects Across Cohorts
  (a) Cohort-specific pre/post Z₁ coefficients (early vs late openers)
  (b) Z₁ × event_time continuous interaction (gradual vs discrete fade)
  (c) Demographic trajectory divergence across cohorts

Output:
  phase8_confounder_partial_r2.md — Partial R² of Z₁ on observables, pre vs post
  phase8_confounder_mediation.md — Z₁ attenuation with structural controls, pre vs post
  phase8_oster_bounds.md — Oster δ bounds, pre vs post
  phase8_gross_flows.md — Z₁ on gross asset/liability components, pre vs post
  phase8_ca_components.md — Z₁ on CA components, pre vs post
  phase8_cohort_attenuation.md — Cohort-specific Z₁ pre/post coefficients
  phase8_event_time_interaction.md — Z₁ × event_time continuous fade
  phase8_summary.md — Summary table of all three mechanism tests with verdicts
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Filesystem layout. NOTE: the project root is a hard-coded absolute path.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
CAUSAL_DIR = PROJECT_DIR / "causal_identification"
PROCESSED_DIR = CAUSAL_DIR / "data" / "processed"  # load_panel() reads causal_panel.csv from here
OUTPUT_DIR = CAUSAL_DIR / "output" / "tables"  # every phase8_*.md report is written here
# Import-time side effect: create the output directory (and parents) if missing.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Put the multilateral project first on sys.path so the `src` package
# imported on the next line is looked up there before anywhere else.
sys.path.insert(0, str(MULTILATERAL_DIR))
from src.model import PanelGLS

# Demographic regressors; Z_1 is the coefficient every test below reports.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
# Standard controls added to the demographic regressors in the baseline models.
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'trade_openness', 'log_rel_opw']

# Structural confounders — Soviet-era proxies
# STRUCTURAL_CONTROLS are appended to CONTROLS in Tests 1b/1c;
# Z1_PREDICTORS are the right-hand-side variables for Z_1 in Test 1a.
STRUCTURAL_CONTROLS = ['rule_of_law', 'control_corruption', 'gross_savings_gdp', 'fdi_liab_gdp']
Z1_PREDICTORS = ['trade_openness', 'rule_of_law', 'control_corruption', 'nfa_gdp_lag', 'gross_savings_gdp']

# Gross flow components
# (column name, display label) pairs; Test 2a uses each column as a dependent
# variable and keys its offsetting-sign check on the 'Asset'/'Liabilit'
# substrings of the labels.
GROSS_FLOW_VARS = [
    ('gross_assets_gdp', 'Gross Assets'),
    ('gross_liab_gdp', 'Gross Liabilities'),
    ('fdi_assets_gdp', 'FDI Assets'),
    ('fdi_liab_gdp', 'FDI Liabilities'),
    ('debt_assets_gdp', 'Debt Assets'),
    ('debt_liab_gdp', 'Debt Liabilities'),
    ('fx_reserves_gdp', 'FX Reserves'),
]

# CA components
# (column name, display label) pairs used as dependent variables in Test 2b,
# which looks results up again by the 'Gross Savings'/'Gross Investment' labels.
CA_COMPONENT_VARS = [
    ('ca_gdp', 'Current Account'),
    ('savings_investment_gap', 'S-I Gap'),
    ('gross_savings_gdp', 'Gross Savings'),
    ('gross_investment_gdp_x', 'Gross Investment'),
]


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, and '' otherwise
    (including when every comparison is false, e.g. for NaN).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fv(v, p=None):
    """Render *v* with three decimals, optionally followed by stars.

    Missing values (per ``pd.isna``) render as '--'. Stars from ``stars(p)``
    are appended only when *p* is supplied and is not missing.
    """
    if pd.isna(v):
        return '--'
    suffix = '' if (p is None or pd.isna(p)) else stars(p)
    return f"{v:.3f}" + suffix


def load_panel():
    """Read ``causal_panel.csv`` from PROCESSED_DIR, keeping years 1992–2024.

    Returns:
        A copy of the rows whose ``year`` lies in [1992, 2024].
    """
    panel = pd.read_csv(PROCESSED_DIR / "causal_panel.csv", low_memory=False)
    in_window = (panel['year'] >= 1992) & (panel['year'] <= 2024)
    return panel[in_window].copy()


def run_gls(df, dep_var, indep_vars, label=""):
    """Fit a PanelGLS of *dep_var* on *indep_vars* and summarise the fit.

    Regressors that are not columns of *df* are silently skipped, and rows
    with a missing value in the dependent variable or any remaining regressor
    are dropped before fitting.

    Args:
        df: Panel with ``iso3`` and ``year`` columns.
        dep_var: Name of the dependent-variable column.
        indep_vars: Candidate regressor column names.
        label: Tag stored in the result and used in the console message.

    Returns:
        ``None`` when the complete-case sample has fewer than 3 countries or
        fewer than 30 rows; otherwise a dict with ``label``, ``r_squared``,
        ``n_obs``, ``n_countries`` and, per regressor ``v``, the keys
        ``{v}_coef``, ``{v}_se`` and ``{v}_pval``.
    """
    regressors = [name for name in indep_vars if name in df.columns]
    sample = df.dropna(subset=[dep_var] + regressors).copy()

    too_few_countries = sample['iso3'].nunique() < 3
    too_few_rows = len(sample) < 30
    if too_few_countries or too_few_rows:
        print(f"  {label}: insufficient obs ({len(sample)})")
        return None

    model = PanelGLS()
    model.fit(
        sample[dep_var].values,
        sample[regressors].values,
        sample['iso3'].values,
        sample['year'].values,
    )

    summary = {
        'label': label,
        'r_squared': model.r_squared,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
    }
    # Estimates are positional: entry k belongs to the k-th regressor passed in.
    for pos, name in enumerate(regressors):
        summary.update({
            f'{name}_coef': model.beta[pos],
            f'{name}_se': model.se[pos],
            f'{name}_pval': model.pvalues[pos],
        })
    return summary


def get_transition_subsamples(df):
    """Carve the panel into transition, opener, pre- and post-opening frames.

    Returns:
        Tuple ``(trans, openers, pre, post)`` of independent copies:
        rows with ``is_transition == 1``; those among them with
        ``status == 'opener'``; and the opener rows with ``event_time < 0``
        and ``event_time >= 0`` respectively. Rows with a missing
        ``event_time`` land in neither ``pre`` nor ``post``.
    """
    is_transition = df['is_transition'] == 1
    trans = df[is_transition].copy()

    is_opener = trans['status'] == 'opener'
    openers = trans[is_opener].copy()

    before_opening = openers['event_time'] < 0
    from_opening = openers['event_time'] >= 0
    pre = openers[before_opening].copy()
    post = openers[from_opening].copy()
    return trans, openers, pre, post


# =====================================================================
# TEST 1: DISRUPTION OF STRUCTURAL CONFOUNDERS
# =====================================================================

def test1a_partial_r2(df):
    """Compare how well observables predict Z₁ before vs after opening.

    For the pre- and post-opening opener subsamples, regresses Z₁ on all
    available ``Z1_PREDICTORS`` jointly and on each predictor alone (via
    ``run_gls``), then writes the R² values as a markdown table to
    ``phase8_confounder_partial_r2.md`` in ``OUTPUT_DIR``.

    NOTE: the per-predictor columns are R² values from single-regressor fits
    on the joint complete-case sample, not partial R² in the leave-one-out
    sense.

    Args:
        df: Panel accepted by ``get_transition_subsamples``.

    Returns:
        List of dicts, one per period that could be estimated, holding
        ``period``, ``r2_full``, ``n_obs``, ``n_countries`` and an
        ``r2_<predictor>`` entry for each single-predictor fit that ran.
    """
    print("\n" + "=" * 70)
    print("TEST 1a: PARTIAL R² — Z₁ PREDICTED BY OBSERVABLES")
    print("=" * 70)

    _, _, pre, post = get_transition_subsamples(df)

    results = []
    for period_label, period_df in [('Pre-opening', pre), ('Post-opening', post)]:
        available = [v for v in Z1_PREDICTORS if v in period_df.columns]
        # Fix the complete-case sample up front so the full and the
        # single-predictor fits are all estimated on the same rows.
        comp = period_df.dropna(subset=['Z_1'] + available)
        if len(comp) < 30:
            print(f"  {period_label}: insufficient obs ({len(comp)})")
            continue

        # Full model: Z₁ ~ all predictors
        r_full = run_gls(comp, 'Z_1', available, f"{period_label}: Full")
        if r_full is None:
            continue

        row = {
            'period': period_label,
            'r2_full': r_full['r_squared'],
            'n_obs': r_full['n_obs'],
            'n_countries': r_full['n_countries'],
        }

        # Individual predictor R²
        for v in available:
            r_single = run_gls(comp, 'Z_1', [v], f"{period_label}: {v}")
            if r_single:
                row[f'r2_{v}'] = r_single['r_squared']

        results.append(row)
        print(f"  {period_label}: R²(full) = {r_full['r_squared']:.4f}, "
              f"N = {r_full['n_obs']}, Countries = {r_full['n_countries']}")

    # Write table
    lines = ["# Test 1a: Partial R² — Z₁ Predicted by Observables\n"]
    lines.append("*How well do Soviet-era structural proxies predict Z₁, pre vs post opening?*\n")
    lines.append("*If confounders hypothesis is correct: pre R² >> post R².*\n")

    available = [v for v in Z1_PREDICTORS if v in df.columns]
    header = "| Period | R² (full) | " + " | ".join(f"R²({v})" for v in available) + " | N | Countries |"
    sep = "|:---|---:|" + "---:|" * len(available) + "---:|---:|"
    lines.append(header)
    lines.append(sep)

    for r in results:
        row_parts = [r['period'], fv(r['r2_full'])]
        for v in available:
            row_parts.append(fv(r.get(f'r2_{v}', np.nan)))
        row_parts.extend([str(r['n_obs']), str(r['n_countries'])])
        lines.append("| " + " | ".join(row_parts) + " |")

    # Two rows means both periods were estimated, in pre-then-post order.
    if len(results) == 2:
        r2_pre = results[0]['r2_full']
        r2_post = results[1]['r2_full']
        ratio = r2_pre / r2_post if r2_post > 0 else np.inf
        lines.append(f"\n**Pre/Post R² ratio: {ratio:.2f}×**")
        if r2_pre > r2_post:
            lines.append("*Observables predict Z₁ better pre-opening → consistent with confounders hypothesis.*")
        else:
            lines.append("*Observables predict Z₁ comparably or better post-opening → against confounders hypothesis.*")

    lines.append("\n*Dependent variable: Z₁. PanelGLS with entity and time effects.*")
    lines.append("*Predictors: trade openness, rule of law, control of corruption, lagged NFA, gross savings.*")

    path = OUTPUT_DIR / "phase8_confounder_partial_r2.md"
    # The report contains non-ASCII characters (Z₁, R², →, ×); pin UTF-8 so
    # the write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


def test1b_mediation(df):
    """Measure how much structural controls attenuate Z₁ in the CA regression.

    For the pre- and post-opening opener subsamples, fits ``ca_gdp`` on
    ``DEMO_VARS + CONTROLS`` (baseline) and again with ``STRUCTURAL_CONTROLS``
    added, reports the percentage change in the Z₁ coefficient, and writes a
    markdown table to ``phase8_confounder_mediation.md`` in ``OUTPUT_DIR``.

    NOTE(review): ``run_gls`` drops rows with missing values in whichever
    regressors it is given, so the baseline and the augmented fit can be
    estimated on different rows (the table shows both N). The attenuation
    therefore mixes the effect of the added controls with any sample change.

    Args:
        df: Panel accepted by ``get_transition_subsamples``.

    Returns:
        List of dicts, one per period in which both fits succeeded, with the
        baseline and mediated Z₁ estimates, R² values, ``attenuation_pct``
        and both sample sizes.
    """
    print("\n" + "=" * 70)
    print("TEST 1b: MEDIATION — STRUCTURAL CONTROLS ATTENUATE Z₁")
    print("=" * 70)

    _, _, pre, post = get_transition_subsamples(df)

    results = []
    for period_label, period_df in [('Pre-opening', pre), ('Post-opening', post)]:
        # Baseline: Z + standard controls
        r_base = run_gls(period_df, 'ca_gdp', DEMO_VARS + CONTROLS,
                         f"{period_label}: Baseline")
        if r_base is None:
            continue

        # Mediated: Z + standard controls + structural controls
        r_med = run_gls(period_df, 'ca_gdp', DEMO_VARS + CONTROLS + STRUCTURAL_CONTROLS,
                        f"{period_label}: +Structural")
        if r_med is None:
            continue

        z1_base = r_base.get('Z_1_coef', np.nan)
        z1_med = r_med.get('Z_1_coef', np.nan)
        # Percentage of the baseline coefficient removed by the extra controls;
        # undefined (NaN) when the baseline coefficient is exactly zero.
        attenuation = ((z1_base - z1_med) / z1_base * 100) if z1_base != 0 else np.nan

        row = {
            'period': period_label,
            'z1_baseline': z1_base,
            'z1_baseline_se': r_base.get('Z_1_se', np.nan),
            'z1_baseline_p': r_base.get('Z_1_pval', np.nan),
            'r2_baseline': r_base['r_squared'],
            'z1_mediated': z1_med,
            'z1_mediated_se': r_med.get('Z_1_se', np.nan),
            'z1_mediated_p': r_med.get('Z_1_pval', np.nan),
            'r2_mediated': r_med['r_squared'],
            'attenuation_pct': attenuation,
            'n_obs_base': r_base['n_obs'],
            'n_obs_med': r_med['n_obs'],
        }
        results.append(row)

        print(f"  {period_label}:")
        print(f"    Baseline Z₁ = {z1_base:.3f} (p={r_base.get('Z_1_pval', np.nan):.4f})")
        print(f"    Mediated Z₁ = {z1_med:.3f} (p={r_med.get('Z_1_pval', np.nan):.4f})")
        print(f"    Attenuation = {attenuation:.1f}%")

    # Write table
    lines = ["# Test 1b: Mediation — Structural Controls Attenuate Z₁\n"]
    lines.append("*If confounders drive the pre-opening Z₁ effect, adding structural controls "
                 "should attenuate Z₁ more in pre-opening than post-opening.*\n")

    lines.append("| Period | Z₁ (baseline) | SE | p | R² | Z₁ (+structural) | SE | p | R² | Attenuation | N |")
    lines.append("|:---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|")

    for r in results:
        lines.append(
            f"| {r['period']} "
            f"| {fv(r['z1_baseline'], r['z1_baseline_p'])} "
            f"| ({fv(r['z1_baseline_se'])}) "
            f"| {fv(r['z1_baseline_p'])} "
            f"| {fv(r['r2_baseline'])} "
            f"| {fv(r['z1_mediated'], r['z1_mediated_p'])} "
            f"| ({fv(r['z1_mediated_se'])}) "
            f"| {fv(r['z1_mediated_p'])} "
            f"| {fv(r['r2_mediated'])} "
            f"| {r['attenuation_pct']:.1f}% "
            f"| {r['n_obs_base']}/{r['n_obs_med']} |"
        )

    # Two rows means both periods were estimated, in pre-then-post order.
    if len(results) == 2:
        att_pre = results[0]['attenuation_pct']
        att_post = results[1]['attenuation_pct']
        lines.append(f"\n**Pre-opening attenuation: {att_pre:.1f}%. Post-opening attenuation: {att_post:.1f}%.**")
        # A comparison involving NaN is always False and would silently print
        # the "against" verdict, so only draw a conclusion when both are defined.
        if not pd.isna(att_pre) and not pd.isna(att_post):
            if att_pre > att_post:
                lines.append("*Greater attenuation pre-opening → confounders drive pre-opening Z₁ effect.*")
            else:
                lines.append("*Attenuation comparable or greater post-opening → against confounders hypothesis.*")

    lines.append("\n*Baseline: ca_gdp ~ Z₁ Z₂ Z₃ + fiscal_bal_gdp + nfa_gdp_lag + trade_openness + log_rel_opw.*")
    lines.append("*+Structural: adds rule_of_law, control_corruption, gross_savings_gdp, fdi_liab_gdp.*")
    lines.append("*Transition economy openers only. PanelGLS with entity and time effects.*")

    path = OUTPUT_DIR / "phase8_confounder_mediation.md"
    # The report contains non-ASCII characters (Z₁, R², →); pin UTF-8 so the
    # write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


def test1c_oster_bounds(df):
    """Compute Oster (2019) δ for the Z₁ coefficient, pre vs post opening.

    For each opener subsample, fits ``ca_gdp`` on ``DEMO_VARS`` alone
    (uncontrolled) and on ``DEMO_VARS + CONTROLS + STRUCTURAL_CONTROLS``
    (controlled), derives δ from the coefficient and R² movements, and writes
    a markdown table to ``phase8_oster_bounds.md`` in ``OUTPUT_DIR``.

    δ uses Oster's linear approximation of the bias-adjusted coefficient,

        β* ≈ β_ctrl − δ · (β_unc − β_ctrl) · (R²_max − R²_ctrl) / (R²_ctrl − R²_unc),

    solved for the δ that sets β* = 0:

        δ = β_ctrl · (R²_ctrl − R²_unc) / [(β_unc − β_ctrl) · (R²_max − R²_ctrl)].

    NOTE(review): ``run_gls`` drops rows with missing values in whichever
    regressors it is given, so the uncontrolled and controlled fits can use
    different rows; the reported N is that of the controlled fit.

    Args:
        df: Panel accepted by ``get_transition_subsamples``.

    Returns:
        List of dicts, one per period in which both fits succeeded, with the
        two coefficients, the two R² values, ``r2_max``, ``delta`` and
        ``n_obs``. ``delta`` is NaN when the formula's denominator is
        (numerically) zero or undefined.
    """
    print("\n" + "=" * 70)
    print("TEST 1c: OSTER (2019) δ BOUNDS")
    print("=" * 70)

    _, _, pre, post = get_transition_subsamples(df)

    results = []
    for period_label, period_df in [('Pre-opening', pre), ('Post-opening', post)]:
        # Uncontrolled: ca_gdp ~ Z₁ Z₂ Z₃
        r_unc = run_gls(period_df, 'ca_gdp', DEMO_VARS,
                        f"{period_label}: Uncontrolled")
        # Controlled: ca_gdp ~ Z₁ Z₂ Z₃ + all controls + structural
        all_controls = CONTROLS + STRUCTURAL_CONTROLS
        r_ctrl = run_gls(period_df, 'ca_gdp', DEMO_VARS + all_controls,
                         f"{period_label}: Controlled")

        if r_unc is None or r_ctrl is None:
            continue

        beta_unc = r_unc.get('Z_1_coef', np.nan)
        beta_ctrl = r_ctrl.get('Z_1_coef', np.nan)
        r2_unc = r_unc['r_squared']
        r2_ctrl = r_ctrl['r_squared']
        r2_max = min(1.3 * r2_ctrl, 1.0)

        # Oster δ formula: coefficient movement is scaled by the R² still
        # unexplained (R²_max − R²_ctrl) in the denominator and by the R²
        # gained from the controls (R²_ctrl − R²_unc) in the numerator.
        denom = (beta_unc - beta_ctrl) * (r2_max - r2_ctrl)
        if abs(denom) > 1e-10:
            delta = (beta_ctrl * (r2_ctrl - r2_unc)) / denom
        else:
            # No coefficient movement or no unexplained variance left
            # (also reached when any input is NaN): δ is undefined.
            delta = np.nan

        row = {
            'period': period_label,
            'beta_uncontrolled': beta_unc,
            'beta_controlled': beta_ctrl,
            'r2_uncontrolled': r2_unc,
            'r2_controlled': r2_ctrl,
            'r2_max': r2_max,
            'delta': delta,
            'n_obs': r_ctrl['n_obs'],
        }
        results.append(row)

        print(f"  {period_label}:")
        print(f"    β_unc = {beta_unc:.3f}, β_ctrl = {beta_ctrl:.3f}")
        print(f"    R²_unc = {r2_unc:.4f}, R²_ctrl = {r2_ctrl:.4f}, R²_max = {r2_max:.4f}")
        print(f"    δ = {delta:.3f}")

    # Write table
    lines = ["# Test 1c: Oster (2019) δ Bounds\n"]
    lines.append("*δ measures proportional selection on unobservables vs observables needed "
                 "to drive β to zero.*\n")
    lines.append("*If δ_pre < δ_post, unobservables matter more pre-opening (confounders hypothesis).*\n")
    lines.append("*|δ| > 1 means unobservables would need to be MORE important than observables "
                 "to explain away the result — robust.*\n")

    lines.append("| Period | β (uncontrolled) | β (controlled) | R² (unc) | R² (ctrl) | R²_max | δ | N |")
    lines.append("|:---|---:|---:|---:|---:|---:|---:|---:|")

    for r in results:
        lines.append(
            f"| {r['period']} "
            f"| {fv(r['beta_uncontrolled'])} "
            f"| {fv(r['beta_controlled'])} "
            f"| {fv(r['r2_uncontrolled'])} "
            f"| {fv(r['r2_controlled'])} "
            f"| {fv(r['r2_max'])} "
            f"| {fv(r['delta'])} "
            f"| {r['n_obs']} |"
        )

    # Two rows means both periods were estimated, in pre-then-post order.
    if len(results) == 2:
        d_pre = results[0]['delta']
        d_post = results[1]['delta']
        lines.append(f"\n**δ_pre = {fv(d_pre)}, δ_post = {fv(d_post)}**")
        if not pd.isna(d_pre) and not pd.isna(d_post):
            if abs(d_pre) < abs(d_post):
                lines.append("*|δ_pre| < |δ_post| → pre-opening result more vulnerable to "
                             "unobservables → consistent with confounders hypothesis.*")
            else:
                lines.append("*|δ_pre| ≥ |δ_post| → pre-opening result NOT more vulnerable "
                             "to unobservables → against confounders hypothesis.*")

    lines.append("\n*R²_max = min(1.3 × R²_ctrl, 1.0) following Oster (2019).*")
    lines.append("*Uncontrolled: ca_gdp ~ Z. Controlled: ca_gdp ~ Z + all controls + structural.*")
    lines.append("*Transition economy openers only.*")

    path = OUTPUT_DIR / "phase8_oster_bounds.md"
    # The report contains non-ASCII characters (δ, β, R², →, ≥); pin UTF-8 so
    # the write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


# =====================================================================
# TEST 2: COUNTERVAILING CAPITAL FLOWS
# =====================================================================

def test2a_gross_flows(df):
    """Regress each gross-flow component on Z and controls, pre vs post opening.

    Every column in ``GROSS_FLOW_VARS`` that exists in *df* is used as the
    dependent variable of a ``run_gls`` fit on ``DEMO_VARS + CONTROLS`` for
    the pre- and post-opening opener subsamples. The Z₁/Z₂ estimates are
    written as a markdown table to ``phase8_gross_flows.md`` in
    ``OUTPUT_DIR``, followed by a sign check on the post-opening Z₁
    coefficients.

    Args:
        df: Panel accepted by ``get_transition_subsamples``.

    Returns:
        List of dicts, one per (component, period) fit that succeeded, with
        ``dep_var`` (display label), ``period``, Z₁ coefficient/SE/p-value,
        Z₂ coefficient/p-value, ``r2`` and ``n_obs``.
    """
    print("\n" + "=" * 70)
    print("TEST 2a: GROSS FLOW DECOMPOSITION")
    print("=" * 70)

    _, _, pre, post = get_transition_subsamples(df)

    results = []
    for dep_var, dep_label in GROSS_FLOW_VARS:
        if dep_var not in df.columns:
            continue
        for period_label, period_df in [('Pre', pre), ('Post', post)]:
            r = run_gls(period_df, dep_var, DEMO_VARS + CONTROLS,
                        f"{period_label}: {dep_label}")
            if r:
                results.append({
                    'dep_var': dep_label,
                    'period': period_label,
                    'z1_coef': r.get('Z_1_coef', np.nan),
                    'z1_se': r.get('Z_1_se', np.nan),
                    'z1_p': r.get('Z_1_pval', np.nan),
                    'z2_coef': r.get('Z_2_coef', np.nan),
                    'z2_p': r.get('Z_2_pval', np.nan),
                    'r2': r['r_squared'],
                    'n_obs': r['n_obs'],
                })

    # Write table
    lines = ["# Test 2a: Gross Flow Decomposition — Z₁ on Components, Pre vs Post\n"]
    lines.append("*If countervailing flows: Z₁ should predict assets and liabilities "
                 "in offsetting directions post-opening.*\n")

    lines.append("| Dependent Variable | Period | Z₁ | SE | p | Z₂ | p | R² | N |")
    lines.append("|:---|:---|---:|---:|---:|---:|---:|---:|---:|")

    for r in results:
        lines.append(
            f"| {r['dep_var']} | {r['period']} "
            f"| {fv(r['z1_coef'], r['z1_p'])} "
            f"| ({fv(r['z1_se'])}) "
            f"| {fv(r['z1_p'])} "
            f"| {fv(r['z2_coef'], r['z2_p'])} "
            f"| {fv(r['z2_p'])} "
            f"| {fv(r['r2'])} "
            f"| {r['n_obs']} |"
        )

    # Check for offsetting patterns
    # Components are grouped by display label: labels containing 'Asset' vs
    # labels containing 'Liabilit' (so 'FX Reserves' is in neither group).
    lines.append("\n## Offsetting Pattern Check\n")
    asset_post = [r for r in results if r['period'] == 'Post' and 'Asset' in r['dep_var']]
    liab_post = [r for r in results if r['period'] == 'Post' and 'Liabilit' in r['dep_var']]

    if asset_post and liab_post:
        asset_signs = [np.sign(r['z1_coef']) for r in asset_post if not pd.isna(r['z1_coef'])]
        liab_signs = [np.sign(r['z1_coef']) for r in liab_post if not pd.isna(r['z1_coef'])]
        if asset_signs and liab_signs:
            # NOTE: "SAME" requires every pooled asset and liability sign to
            # equal the first asset sign; any disagreement — including one
            # between two asset components — produces the "OPPOSITE" message.
            same_sign = all(s == asset_signs[0] for s in asset_signs + liab_signs)
            if same_sign:
                lines.append("*Post-opening: assets and liabilities move in SAME direction → "
                             "no offsetting pattern.*")
            else:
                lines.append("*Post-opening: assets and liabilities move in OPPOSITE directions → "
                             "consistent with countervailing flows.*")

    lines.append("\n*Controls: fiscal_bal_gdp, nfa_gdp_lag, trade_openness, log_rel_opw.*")
    lines.append("*Transition economy openers only. PanelGLS with entity and time effects.*")

    path = OUTPUT_DIR / "phase8_gross_flows.md"
    # The report contains non-ASCII characters (Z₁, R², →, —); pin UTF-8 so
    # the write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


def test2b_ca_components(df):
    """Regress CA components and reserve changes on Z, pre vs post opening.

    Every column in ``CA_COMPONENT_VARS`` that exists in *df* is used as the
    dependent variable of a ``run_gls`` fit on ``DEMO_VARS + CONTROLS`` for
    the pre- and post-opening opener subsamples. The within-country
    year-to-year change in ``fx_reserves_gdp`` is added as a capital-flight
    proxy when that column is present. Results are written as a markdown
    table to ``phase8_ca_components.md`` in ``OUTPUT_DIR``.

    Args:
        df: Panel accepted by ``get_transition_subsamples``; must carry
            ``iso3`` and ``year``.

    Returns:
        List of dicts, one per (component, period) fit that succeeded, with
        ``dep_var`` (display label), ``period``, Z₁ coefficient/SE/p-value,
        ``r2`` and ``n_obs``.
    """
    print("\n" + "=" * 70)
    print("TEST 2b: CA COMPONENT SPLIT")
    print("=" * 70)

    _, _, pre, post = get_transition_subsamples(df)

    results = []
    for dep_var, dep_label in CA_COMPONENT_VARS:
        if dep_var not in df.columns:
            continue
        for period_label, period_df in [('Pre', pre), ('Post', post)]:
            r = run_gls(period_df, dep_var, DEMO_VARS + CONTROLS,
                        f"{period_label}: {dep_label}")
            if r:
                results.append({
                    'dep_var': dep_label,
                    'period': period_label,
                    'z1_coef': r.get('Z_1_coef', np.nan),
                    'z1_se': r.get('Z_1_se', np.nan),
                    'z1_p': r.get('Z_1_pval', np.nan),
                    'r2': r['r_squared'],
                    'n_obs': r['n_obs'],
                })

    # Also test Δfx_reserves as capital flight proxy
    print("\n  Testing Δfx_reserves as capital flight proxy...")
    if 'fx_reserves_gdp' in df.columns:
        # Difference within country on the whole panel BEFORE splitting into
        # pre/post. Differencing inside each subsample would turn every
        # country's first post-opening year into NaN even though the previous
        # year's reserves are observed.
        panel = df.sort_values(['iso3', 'year']).copy()
        panel['delta_fx_reserves'] = panel.groupby('iso3')['fx_reserves_gdp'].diff()
        _, _, fx_pre, fx_post = get_transition_subsamples(panel)
        for period_label, period_df in [('Pre', fx_pre), ('Post', fx_post)]:
            r = run_gls(period_df, 'delta_fx_reserves', DEMO_VARS + CONTROLS,
                        f"{period_label}: Δ FX Reserves")
            if r:
                results.append({
                    'dep_var': 'Δ FX Reserves',
                    'period': period_label,
                    'z1_coef': r.get('Z_1_coef', np.nan),
                    'z1_se': r.get('Z_1_se', np.nan),
                    'z1_p': r.get('Z_1_pval', np.nan),
                    'r2': r['r_squared'],
                    'n_obs': r['n_obs'],
                })
    else:
        # Same treatment as the component loop above: skip what is absent
        # instead of failing with a KeyError.
        print("  fx_reserves_gdp not in panel; skipping Δ FX Reserves")

    # Write table
    lines = ["# Test 2b: CA Components and Capital Flight Proxy\n"]
    lines.append("*Z₁ on CA components and reserve changes, pre vs post opening.*\n")

    lines.append("| Component | Period | Z₁ | SE | p | R² | N |")
    lines.append("|:---|:---|---:|---:|---:|---:|---:|")

    for r in results:
        lines.append(
            f"| {r['dep_var']} | {r['period']} "
            f"| {fv(r['z1_coef'], r['z1_p'])} "
            f"| ({fv(r['z1_se'])}) "
            f"| {fv(r['z1_p'])} "
            f"| {fv(r['r2'])} "
            f"| {r['n_obs']} |"
        )

    # Look for offsetting savings/investment
    sav_pre = [r for r in results if r['dep_var'] == 'Gross Savings' and r['period'] == 'Pre']
    inv_pre = [r for r in results if r['dep_var'] == 'Gross Investment' and r['period'] == 'Pre']
    sav_post = [r for r in results if r['dep_var'] == 'Gross Savings' and r['period'] == 'Post']
    inv_post = [r for r in results if r['dep_var'] == 'Gross Investment' and r['period'] == 'Post']

    # The side-by-side summary is only written when all four fits exist.
    if sav_pre and inv_pre and sav_post and inv_post:
        lines.append("\n## S-I Offset Pattern\n")
        lines.append(f"*Pre: Z₁→Savings = {fv(sav_pre[0]['z1_coef'])}, "
                     f"Z₁→Investment = {fv(inv_pre[0]['z1_coef'])}*")
        lines.append(f"*Post: Z₁→Savings = {fv(sav_post[0]['z1_coef'])}, "
                     f"Z₁→Investment = {fv(inv_post[0]['z1_coef'])}*")

    lines.append("\n*Controls: fiscal_bal_gdp, nfa_gdp_lag, trade_openness, log_rel_opw.*")
    lines.append("*Transition economy openers only. PanelGLS with entity and time effects.*")

    path = OUTPUT_DIR / "phase8_ca_components.md"
    # The report contains non-ASCII characters (Z₁, R², Δ, →); pin UTF-8 so
    # the write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


# =====================================================================
# TEST 3: COMPOSITION EFFECTS ACROSS COHORTS
# =====================================================================

def test3a_cohort_attenuation(df):
    """Estimate pre/post Z₁ coefficients separately for opener cohorts.

    Openers are split by ``opening_year`` into early (< 2000) and late
    (≥ 2000) cohorts; for each cohort and for all openers together,
    ``ca_gdp`` is fit on ``DEMO_VARS + CONTROLS`` via ``run_gls`` on the
    pre- (``event_time < 0``) and post-opening (``event_time >= 0``) rows.
    The estimates and post/pre coefficient ratios are written to
    ``phase8_cohort_attenuation.md`` in ``OUTPUT_DIR``.

    Args:
        df: Panel accepted by ``get_transition_subsamples``; must carry
            ``opening_year``.

    Returns:
        List of dicts, one per (cohort, period) fit that succeeded, with
        ``cohort``, ``period``, Z₁ coefficient/SE/p-value, ``r2``, ``n_obs``
        and ``n_countries``.
    """
    print("\n" + "=" * 70)
    print("TEST 3a: COHORT-SPECIFIC ATTENUATION")
    print("=" * 70)

    # Same opener definition as every other test in this file.
    _, openers, _, _ = get_transition_subsamples(df)

    # Split into early (opening_year < 2000) and late (≥ 2000)
    early_isos = openers[openers['opening_year'] < 2000]['iso3'].unique()
    late_isos = openers[openers['opening_year'] >= 2000]['iso3'].unique()

    print(f"  Early openers (<2000): {len(early_isos)} countries: {sorted(early_isos)}")
    print(f"  Late openers (≥2000): {len(late_isos)} countries: {sorted(late_isos)}")

    results = []
    for cohort_label, cohort_isos in [('Early (<2000)', early_isos),
                                       ('Late (≥2000)', late_isos),
                                       ('All openers', openers['iso3'].unique())]:
        cohort_df = openers[openers['iso3'].isin(cohort_isos)]
        for period_label, period_filter in [('Pre', cohort_df['event_time'] < 0),
                                             ('Post', cohort_df['event_time'] >= 0)]:
            period_df = cohort_df[period_filter]
            # Cheap pre-check on raw rows; run_gls re-checks after dropping
            # rows with missing values.
            if len(period_df) < 30:
                print(f"  {cohort_label} {period_label}: too few obs ({len(period_df)})")
                continue

            r = run_gls(period_df, 'ca_gdp', DEMO_VARS + CONTROLS,
                        f"{cohort_label} {period_label}")
            if r:
                results.append({
                    'cohort': cohort_label,
                    'period': period_label,
                    'z1_coef': r.get('Z_1_coef', np.nan),
                    'z1_se': r.get('Z_1_se', np.nan),
                    'z1_p': r.get('Z_1_pval', np.nan),
                    'r2': r['r_squared'],
                    'n_obs': r['n_obs'],
                    'n_countries': r['n_countries'],
                })
                print(f"  {cohort_label} {period_label}: Z₁ = {r.get('Z_1_coef', np.nan):.3f} "
                      f"(p={r.get('Z_1_pval', np.nan):.4f}), N={r['n_obs']}")

    # Compute attenuation ratios per cohort
    lines = ["# Test 3a: Cohort-Specific Pre/Post Z₁ Attenuation\n"]
    lines.append("*If composition effects drive attenuation, the pattern should be "
                 "cohort-specific, not universal.*\n")

    lines.append("| Cohort | Period | Z₁ | SE | p | R² | N | Countries |")
    lines.append("|:---|:---|---:|---:|---:|---:|---:|---:|")

    for r in results:
        lines.append(
            f"| {r['cohort']} | {r['period']} "
            f"| {fv(r['z1_coef'], r['z1_p'])} "
            f"| ({fv(r['z1_se'])}) "
            f"| {fv(r['z1_p'])} "
            f"| {fv(r['r2'])} "
            f"| {r['n_obs']} "
            f"| {r['n_countries']} |"
        )

    # Attenuation summary
    lines.append("\n## Attenuation Ratios\n")
    for cohort in ['Early (<2000)', 'Late (≥2000)', 'All openers']:
        pre_r = [r for r in results if r['cohort'] == cohort and r['period'] == 'Pre']
        post_r = [r for r in results if r['cohort'] == cohort and r['period'] == 'Post']
        if pre_r and post_r:
            z1_pre = pre_r[0]['z1_coef']
            z1_post = post_r[0]['z1_coef']
            # The ratio is undefined for a zero or missing pre-opening coefficient.
            if z1_pre != 0 and not pd.isna(z1_pre):
                ratio = z1_post / z1_pre
                lines.append(f"- **{cohort}**: Z₁ pre={fv(z1_pre)}, post={fv(z1_post)}, "
                             f"post/pre ratio={ratio:.2f}")

    lines.append("\n*Controls: fiscal_bal_gdp, nfa_gdp_lag, trade_openness, log_rel_opw.*")
    lines.append("*Transition economy openers only. PanelGLS with entity and time effects.*")

    path = OUTPUT_DIR / "phase8_cohort_attenuation.md"
    # The report contains non-ASCII characters (Z₁, R², ≥); pin UTF-8 so the
    # write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


def test3b_event_time_interaction(df):
    """Contrast a gradual and a discrete fade of the Z₁ effect around opening.

    On opener rows with ``event_time`` in [-10, 15], fits two ``run_gls``
    models of ``ca_gdp`` on ``DEMO_VARS + CONTROLS``:

    * Model 1 adds ``event_time`` and ``Z_1 × event_time``;
    * Model 2 adds ``post_opening`` and ``Z_1 × post_opening``.

    Both coefficient tables and a verdict based on the two interaction
    p-values (10% level) are written to ``phase8_event_time_interaction.md``
    in ``OUTPUT_DIR``.

    NOTE(review): the report footers in this file describe PanelGLS as
    including entity and time effects. If so, and if ``event_time`` is
    calendar year minus a country-specific opening year, its main effect
    would be collinear with those effects — confirm against the PanelGLS
    implementation.

    Args:
        df: Panel accepted by ``get_transition_subsamples``. If it has no
            ``post_opening`` column, the dummy is derived as
            ``event_time >= 0``.

    Returns:
        List holding the ``run_gls`` result dict of each model that could be
        estimated (Model 1 first).
    """
    print("\n" + "=" * 70)
    print("TEST 3b: Z₁ × EVENT_TIME CONTINUOUS INTERACTION")
    print("=" * 70)

    _, openers, _, _ = get_transition_subsamples(df)

    # Limit to reasonable event window
    openers = openers[(openers['event_time'] >= -10) & (openers['event_time'] <= 15)].copy()
    openers['Z_1_x_event_time'] = openers['Z_1'] * openers['event_time']

    results = []

    # Model 1: Z₁ + Z₁×event_time (continuous)
    vars1 = DEMO_VARS + CONTROLS + ['event_time', 'Z_1_x_event_time']
    r1 = run_gls(openers, 'ca_gdp', vars1, "Z₁ × event_time (continuous)")
    if r1:
        results.append(r1)
        print(f"  Z₁ = {r1.get('Z_1_coef', np.nan):.3f} (p={r1.get('Z_1_pval', np.nan):.4f})")
        print(f"  Z₁×event_time = {r1.get('Z_1_x_event_time_coef', np.nan):.4f} "
              f"(p={r1.get('Z_1_x_event_time_pval', np.nan):.4f})")

    # Model 2: Z₁ + Z₁×post (discrete break at opening)
    if 'post_opening' not in openers.columns:
        # Fall back to the definition used for the pre/post split elsewhere in
        # this file (event_time >= 0). The window filter above has already
        # removed rows with a missing event_time.
        openers['post_opening'] = (openers['event_time'] >= 0).astype(int)
    openers['Z_1_x_post'] = openers['Z_1'] * openers['post_opening']
    vars2 = DEMO_VARS + CONTROLS + ['post_opening', 'Z_1_x_post']
    r2 = run_gls(openers, 'ca_gdp', vars2, "Z₁ × post (discrete)")
    if r2:
        results.append(r2)
        print(f"  Z₁ = {r2.get('Z_1_coef', np.nan):.3f} (p={r2.get('Z_1_pval', np.nan):.4f})")
        print(f"  Z₁×post = {r2.get('Z_1_x_post_coef', np.nan):.3f} "
              f"(p={r2.get('Z_1_x_post_pval', np.nan):.4f})")

    # Write table
    lines = ["# Test 3b: Z₁ × Event Time — Gradual vs Discrete Fade\n"]
    lines.append("*If the demographic effect fades gradually → composition/learning. "
                 "If discrete break at opening → structural disruption.*\n")

    lines.append("## Model 1: Continuous Fade (Z₁ × event_time)\n")
    if r1:
        lines.append("| Variable | Coefficient | SE | p |")
        lines.append("|:---|---:|---:|---:|")
        # Only the demographic and event-time terms are tabulated; the
        # standard controls are estimated but not shown.
        for v in ['Z_1', 'Z_2', 'Z_3', 'event_time', 'Z_1_x_event_time']:
            coef = r1.get(f'{v}_coef', np.nan)
            se = r1.get(f'{v}_se', np.nan)
            p = r1.get(f'{v}_pval', np.nan)
            lines.append(f"| {v} | {fv(coef, p)} | ({fv(se)}) | {fv(p)} |")
        lines.append(f"\n*R² = {r1['r_squared']:.4f}, N = {r1['n_obs']}*")

    lines.append("\n## Model 2: Discrete Break (Z₁ × post_opening)\n")
    if r2:
        lines.append("| Variable | Coefficient | SE | p |")
        lines.append("|:---|---:|---:|---:|")
        for v in ['Z_1', 'Z_2', 'Z_3', 'post_opening', 'Z_1_x_post']:
            coef = r2.get(f'{v}_coef', np.nan)
            se = r2.get(f'{v}_se', np.nan)
            p = r2.get(f'{v}_pval', np.nan)
            lines.append(f"| {v} | {fv(coef, p)} | ({fv(se)}) | {fv(p)} |")
        lines.append(f"\n*R² = {r2['r_squared']:.4f}, N = {r2['n_obs']}*")

    # Interpretation
    if r1 and r2:
        cont_p = r1.get('Z_1_x_event_time_pval', np.nan)
        disc_p = r2.get('Z_1_x_post_pval', np.nan)
        lines.append("\n## Interpretation\n")
        lines.append(f"- Continuous fade Z₁×event_time: p = {fv(cont_p)}")
        lines.append(f"- Discrete break Z₁×post: p = {fv(disc_p)}")
        if not pd.isna(cont_p) and not pd.isna(disc_p):
            if cont_p < 0.1 and disc_p >= 0.1:
                lines.append("*Gradual fade significant, discrete break not → composition/learning effect.*")
            elif disc_p < 0.1 and cont_p >= 0.1:
                lines.append("*Discrete break significant, gradual fade not → structural disruption at opening.*")
            elif cont_p < 0.1 and disc_p < 0.1:
                lines.append("*Both significant → both mechanisms may operate.*")
            else:
                lines.append("*Neither significant → insufficient power to distinguish.*")

    lines.append("\n*Event window: [-10, +15]. Transition economy openers only.*")
    lines.append("*Controls: fiscal_bal_gdp, nfa_gdp_lag, trade_openness, log_rel_opw.*")

    path = OUTPUT_DIR / "phase8_event_time_interaction.md"
    # The report contains non-ASCII characters (Z₁, R², ×, →); pin UTF-8 so
    # the write does not depend on the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return results


def test3c_trajectory_divergence(df):
    """Compare mean Z₁ paths of early and late openers over event time.

    Restricts ``df`` to transition-economy openers (``is_transition == 1`` and
    ``status == 'opener'``), splits countries into two cohorts by
    ``opening_year`` (before 2000 vs 2000 or later) and tabulates the mean,
    standard deviation and count of ``Z_1`` per cohort for every event time in
    [-10, +15]. The table, plus a comparison of the average early-minus-late
    gap before and after opening, is written to
    ``phase8_trajectory_divergence.md`` in ``OUTPUT_DIR``.

    Args:
        df: Panel DataFrame with columns ``is_transition``, ``status``,
            ``opening_year``, ``iso3``, ``event_time`` and ``Z_1``.

    Returns:
        dict mapping cohort label -> ``{event_time: (mean, sd, n)}``. Only
        cells with at least two non-missing ``Z_1`` observations are present.
    """
    print("\n" + "=" * 70)
    print("TEST 3c: DEMOGRAPHIC TRAJECTORY DIVERGENCE")
    print("=" * 70)

    trans = df[df['is_transition'] == 1]
    openers = trans[trans['status'] == 'opener']

    early_label = 'Early (<2000)'
    late_label = 'Late (≥2000)'
    cohort_isos = {
        early_label: openers.loc[openers['opening_year'] < 2000, 'iso3'].unique(),
        late_label: openers.loc[openers['opening_year'] >= 2000, 'iso3'].unique(),
    }

    event_window = range(-10, 16)
    cohort_data = {}

    for cohort_label, isos in cohort_isos.items():
        cohort_df = openers[openers['iso3'].isin(isos)]
        means = {}
        for e in event_window:
            # Drop missing Z_1 first so that both the ">= 2 observations" gate
            # and the reported N refer to values that actually enter mean/SD.
            vals = cohort_df.loc[cohort_df['event_time'] == e, 'Z_1'].dropna()
            if len(vals) >= 2:
                means[e] = (vals.mean(), vals.std(), len(vals))
        cohort_data[cohort_label] = means

    # Write table
    lines = ["# Test 3c: Demographic Trajectory Divergence Across Cohorts\n"]
    lines.append("*If early and late openers have diverging Z₁ paths, pooling creates "
                 "mechanical attenuation.*\n")

    lines.append("| Event Time | Early Z₁ Mean | (SD) | N | Late Z₁ Mean | (SD) | N | Difference |")
    lines.append("|---:|---:|---:|---:|---:|---:|---:|---:|")

    # Placeholder for event times where a cohort has too few observations.
    missing = (np.nan, np.nan, 0)
    early_means = cohort_data[early_label]
    late_means = cohort_data[late_label]

    # Early-minus-late gaps, collected while rendering so the table and the
    # divergence check below are guaranteed to use the same numbers.
    diffs_pre = []
    diffs_post = []

    for e in event_window:
        early = early_means.get(e, missing)
        late = late_means.get(e, missing)
        if pd.isna(early[0]) or pd.isna(late[0]):
            diff = np.nan
        else:
            diff = early[0] - late[0]
            # Event time 0 (the opening year) counts as post-opening.
            (diffs_pre if e < 0 else diffs_post).append(diff)
        marker = " **←open**" if e == 0 else ""
        lines.append(
            f"| {e:+d}{marker} "
            f"| {fv(early[0])} | ({fv(early[1])}) | {early[2]} "
            f"| {fv(late[0])} | ({fv(late[1])}) | {late[2]} "
            f"| {fv(diff)} |"
        )

    # Test for divergence: is the early-late gap widening over event time?
    if diffs_pre and diffs_post:
        avg_pre_diff = np.mean(diffs_pre)
        avg_post_diff = np.mean(diffs_post)
        lines.append(f"\n**Average early-late Z₁ gap: pre = {avg_pre_diff:.3f}, "
                     f"post = {avg_post_diff:.3f}**")
        # "Widening" means the absolute post gap exceeds 1.5x the pre gap.
        if abs(avg_post_diff) > abs(avg_pre_diff) * 1.5:
            lines.append("*Gap widens post-opening → diverging trajectories support "
                         "composition hypothesis.*")
        else:
            lines.append("*Gap stable or narrows → composition not the primary driver.*")

    lines.append("\n*Transition economy openers only. Early: opening before 2000. "
                 "Late: opening 2000 or after.*")

    path = OUTPUT_DIR / "phase8_trajectory_divergence.md"
    # The report contains non-ASCII symbols (Z₁, ≥, →), so do not rely on the
    # platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")
    return cohort_data


# =====================================================================
# SUMMARY
# =====================================================================

def _summary_confounder_section(t1a_results, t1b_results, t1c_results):
    """Return (markdown lines, verdict label) for Test 1 (structural confounders)."""
    lines = [
        "## Test 1: Disruption of Structural Confounders\n",
        "*Pre-opening, Z₁ proxies for Soviet-era structural characteristics. "
        "Opening breaks this correlation.*\n",
    ]

    # One boolean per sub-test that could be evaluated: True when the pre/post
    # comparison points in the direction the hypothesis predicts. Sub-tests
    # with missing (NaN) statistics cast no vote instead of counting against.
    votes = []

    # Result lists are read positionally; entry 0 is taken as pre-opening and
    # entry 1 as post-opening (NOTE(review): confirm the producers' ordering).
    if t1a_results and len(t1a_results) == 2:
        r2_pre = t1a_results[0]['r2_full']
        r2_post = t1a_results[1]['r2_full']
        ratio = r2_pre / r2_post if r2_post > 0 else np.inf
        lines.append(f"- **1a (Partial R²)**: Pre R² = {r2_pre:.4f}, Post R² = {r2_post:.4f} "
                     f"(ratio = {ratio:.2f}×)")
        if not pd.isna(r2_pre) and not pd.isna(r2_post):
            votes.append(r2_pre > r2_post)

    if t1b_results and len(t1b_results) == 2:
        att_pre = t1b_results[0]['attenuation_pct']
        att_post = t1b_results[1]['attenuation_pct']
        lines.append(f"- **1b (Mediation)**: Pre attenuation = {att_pre:.1f}%, "
                     f"Post attenuation = {att_post:.1f}%")
        if not pd.isna(att_pre) and not pd.isna(att_post):
            votes.append(att_pre > att_post)

    if t1c_results and len(t1c_results) == 2:
        d_pre = t1c_results[0]['delta']
        d_post = t1c_results[1]['delta']
        lines.append(f"- **1c (Oster δ)**: δ_pre = {fv(d_pre)}, δ_post = {fv(d_post)}")
        if not pd.isna(d_pre) and not pd.isna(d_post):
            votes.append(abs(d_pre) < abs(d_post))

    supports = sum(1 for vote in votes if vote)
    if not votes:
        # Absence of evidence must not be reported as evidence against.
        verdict = "INCONCLUSIVE"
        lines.append("\n**Verdict: INCONCLUSIVE** — No confounder test could be evaluated.")
    elif supports >= 2:
        verdict = "SUPPORTED"
        lines.append("\n**Verdict: SUPPORTED** — Majority of tests consistent with confounders hypothesis.")
    elif supports == 1:
        verdict = "MIXED"
        lines.append("\n**Verdict: MIXED** — Some tests consistent, others not.")
    else:
        verdict = "NOT SUPPORTED"
        lines.append("\n**Verdict: NOT SUPPORTED** — Tests do not support confounders hypothesis.")
    return lines, verdict


def _summary_capital_flow_section(t2a_results, t2b_results):
    """Return (markdown lines, verdict label) for Test 2 (countervailing flows)."""
    t2a_results = t2a_results or []
    t2b_results = t2b_results or []

    lines = [
        "\n## Test 2: Countervailing Capital Flows\n",
        "*Opening enables both inflows and outflows that offset the lifecycle "
        "pattern in net terms.*\n",
    ]

    post_assets = [r for r in t2a_results if r['period'] == 'Post' and 'Asset' in r['dep_var']]
    post_liabs = [r for r in t2a_results if r['period'] == 'Post' and 'Liabilit' in r['dep_var']]

    asset_sig = sum(1 for r in post_assets if not pd.isna(r['z1_p']) and r['z1_p'] < 0.1)
    liab_sig = sum(1 for r in post_liabs if not pd.isna(r['z1_p']) and r['z1_p'] < 0.1)

    a_signs = [np.sign(r['z1_coef']) for r in post_assets if not pd.isna(r['z1_coef'])]
    l_signs = [np.sign(r['z1_coef']) for r in post_liabs if not pd.isna(r['z1_coef'])]
    comparable = bool(a_signs and l_signs)
    # Compare the first asset and first liability estimate (the broad
    # aggregates, per the original intent). A negative product means strictly
    # opposite signs; a zero coefficient is not treated as "opposite".
    offsetting = comparable and bool(a_signs[0] * l_signs[0] < 0)

    lines.append(f"- **2a (Gross flows)**: {asset_sig}/{len(post_assets)} asset components significant, "
                 f"{liab_sig}/{len(post_liabs)} liability components significant post-opening")
    if offsetting:
        lines.append("  - Assets and liabilities move in opposite directions → offsetting pattern")
    else:
        lines.append("  - No clear offsetting pattern in gross flows")

    # 2b / 2c are reported as "tested" only; they do not enter the verdict.
    has_savings = any('Savings' in r['dep_var'] for r in t2b_results)
    has_investment = any('Investment' in r['dep_var'] for r in t2b_results)
    if has_savings and has_investment:
        lines.append("- **2b (CA components)**: Savings and investment channels tested")
    if any('FX Reserves' in r['dep_var'] for r in t2b_results):
        lines.append("- **2c (Capital flight)**: Δ FX reserves proxy tested")

    if not comparable:
        verdict = "INCONCLUSIVE"
        lines.append("\n**Verdict: INCONCLUSIVE** — No post-opening gross flow estimates to compare.")
    elif offsetting:
        verdict = "SUPPORTED"
        lines.append("\n**Verdict: SUPPORTED** — Evidence of countervailing gross flows post-opening.")
    else:
        verdict = "NOT SUPPORTED"
        lines.append("\n**Verdict: NOT SUPPORTED** — No clear offsetting pattern in gross flows.")
    return lines, verdict


def _summary_composition_section(t3a_results, t3b_results):
    """Return (markdown lines, verdict label) for Test 3 (composition effects)."""
    t3a_results = t3a_results or []
    early_key = 'Early (<2000)'
    late_key = 'Late (≥2000)'

    lines = [
        "\n## Test 3: Composition Effects Across Cohorts\n",
        "*Early openers (Baltics) vs late openers (Central Asia) have "
        "different trajectories; pooling obscures patterns.*\n",
    ]

    def first_match(cohort, period):
        # First result row for this cohort/period, or None when absent.
        return next((r for r in t3a_results
                     if r['cohort'] == cohort and r['period'] == period), None)

    # cohort key -> (pre coefficient, post coefficient), only when both exist.
    coefs = {}
    for title, cohort in (('Early openers', early_key), ('Late openers', late_key)):
        pre = first_match(cohort, 'Pre')
        post = first_match(cohort, 'Post')
        if pre is None or post is None:
            continue
        b_pre, b_post = pre['z1_coef'], post['z1_coef']
        ratio = b_post / b_pre if b_pre != 0 else np.nan
        lines.append(f"- **{title}**: Z₁ pre = {fv(b_pre)}, "
                     f"post = {fv(b_post)}, ratio = {fv(ratio)}")
        coefs[cohort] = (b_pre, b_post)

    # Event-time interaction results are reported but do not enter the verdict.
    for r in (t3b_results or []):
        if not r:
            continue
        if 'Z_1_x_event_time_coef' in r:
            lines.append(f"- **Continuous fade**: Z₁×event_time = "
                         f"{fv(r['Z_1_x_event_time_coef'], r.get('Z_1_x_event_time_pval'))}")
        if 'Z_1_x_post_coef' in r:
            lines.append(f"- **Discrete break**: Z₁×post = "
                         f"{fv(r['Z_1_x_post_coef'], r.get('Z_1_x_post_pval'))}")

    # Composition verdict: universal attenuation argues against composition.
    # All four coefficients must be present and non-missing; otherwise the
    # |post| < |pre| comparisons would silently evaluate to False on NaN.
    # The non-zero post-coefficient requirement is retained from the original.
    usable = (
        len(coefs) == 2
        and not any(pd.isna(c) for pair in coefs.values() for c in pair)
        and all(b_post != 0 for _, b_post in coefs.values())
    )
    if usable:
        e_att = abs(coefs[early_key][1]) < abs(coefs[early_key][0])
        l_att = abs(coefs[late_key][1]) < abs(coefs[late_key][0])
        if e_att and l_att:
            verdict = "NOT SUPPORTED"
            lines.append("\n**Verdict: NOT SUPPORTED** — Attenuation is universal across cohorts, "
                         "not cohort-specific.")
        elif e_att != l_att:
            verdict = "PARTIALLY SUPPORTED"
            lines.append("\n**Verdict: PARTIALLY SUPPORTED** — Attenuation differs across cohorts.")
        else:
            verdict = "INCONCLUSIVE"
            lines.append("\n**Verdict: INCONCLUSIVE** — Cannot determine pattern from available data.")
    else:
        verdict = "INCONCLUSIVE"
        lines.append("\n**Verdict: INCONCLUSIVE** — Insufficient data for both cohorts.")
    return lines, verdict


def _summary_csv_rows(t1a_results, t1b_results, t1c_results, t2a_results,
                      t2b_results, t3a_results):
    """Flatten the per-test result dicts into rows for the summary CSV."""
    rows = []
    for r in (t1a_results or []):
        rows.append({'test': '1a_partial_r2', 'period': r['period'],
                     'key_stat': r['r2_full'], 'stat_name': 'r2_full'})
    for r in (t1b_results or []):
        rows.append({'test': '1b_mediation', 'period': r['period'],
                     'key_stat': r['attenuation_pct'], 'stat_name': 'attenuation_pct'})
    for r in (t1c_results or []):
        rows.append({'test': '1c_oster', 'period': r['period'],
                     'key_stat': r['delta'], 'stat_name': 'delta'})
    for r in (t2a_results or []):
        rows.append({'test': '2a_gross_flows', 'period': r['period'],
                     'dep_var': r['dep_var'],
                     'key_stat': r['z1_coef'], 'stat_name': 'z1_coef',
                     'p_value': r['z1_p']})
    for r in (t2b_results or []):
        rows.append({'test': '2b_ca_components', 'period': r['period'],
                     'dep_var': r['dep_var'],
                     'key_stat': r['z1_coef'], 'stat_name': 'z1_coef',
                     'p_value': r['z1_p']})
    for r in (t3a_results or []):
        # The cohort label is stored in the 'dep_var' column for this test.
        rows.append({'test': '3a_cohort', 'period': r['period'],
                     'dep_var': r['cohort'],
                     'key_stat': r['z1_coef'], 'stat_name': 'z1_coef',
                     'p_value': r['z1_p']})
    return rows


def write_summary(t1a_results, t1b_results, t1c_results, t2a_results, t2b_results,
                  t3a_results, t3b_results):
    """Write summary table of all three mechanism tests with verdicts.

    Produces ``phase8_summary.md`` (one section per mechanism plus an overall
    table carrying each verdict) and ``phase8_summary.csv`` (one row per
    underlying estimate) in ``OUTPUT_DIR``.

    Args:
        t1a_results: List of dicts with keys ``period`` and ``r2_full``.
        t1b_results: List of dicts with keys ``period`` and ``attenuation_pct``.
        t1c_results: List of dicts with keys ``period`` and ``delta``.
        t2a_results: List of dicts with keys ``period``, ``dep_var``,
            ``z1_coef`` and ``z1_p``.
        t2b_results: Same keys as ``t2a_results``.
        t3a_results: List of dicts with keys ``cohort``, ``period``,
            ``z1_coef`` and ``z1_p``.
        t3b_results: Iterable of dicts that may contain
            ``Z_1_x_event_time_coef``/``_pval`` and ``Z_1_x_post_coef``/``_pval``.

    Any argument may be ``None`` or empty; the affected test is then reported
    as INCONCLUSIVE rather than raising or being counted as evidence against.
    """
    print("\n" + "=" * 70)
    print("WRITING SUMMARY TABLE")
    print("=" * 70)

    lines = ["# Phase 8 Summary: Section 9.2 Hypothesis Tests\n"]
    lines.append("*Three mechanisms for why Z₁'s CA coefficient drops 9:1 after "
                 "capital account opening in transition economies.*\n")

    section_1, verdict_1 = _summary_confounder_section(t1a_results, t1b_results, t1c_results)
    section_2, verdict_2 = _summary_capital_flow_section(t2a_results, t2b_results)
    section_3, verdict_3 = _summary_composition_section(t3a_results, t3b_results)
    lines.extend(section_1)
    lines.extend(section_2)
    lines.extend(section_3)

    # ── Overall verdict ──
    lines.append("\n---\n")
    lines.append("## Overall Assessment\n")
    lines.append("| Mechanism | Key Evidence | Verdict |")
    lines.append("|:---|:---|:---|")
    lines.append(f"| 1. Structural confounders | Partial R², mediation, Oster bounds | {verdict_1} |")
    lines.append(f"| 2. Countervailing flows | Gross flow decomposition | {verdict_2} |")
    lines.append("| 3. Composition effects | Cohort-specific attenuation, trajectory divergence "
                 f"| {verdict_3} |")

    lines.append("\n*All tests: PanelGLS with entity and time effects. "
                 "Transition economy openers only.*")

    path = OUTPUT_DIR / "phase8_summary.md"
    # The report contains non-ASCII symbols (Z₁, δ, ≥, →), so do not rely on
    # the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")

    # Also save CSV summary
    csv_rows = _summary_csv_rows(t1a_results, t1b_results, t1c_results,
                                 t2a_results, t2b_results, t3a_results)
    csv_path = OUTPUT_DIR / "phase8_summary.csv"
    pd.DataFrame(csv_rows).to_csv(csv_path, index=False)
    print(f"  Saved: {csv_path}")


# =====================================================================
# MAIN
# =====================================================================

def main():
    """Load the panel, run every Phase 8 test in order, then write the summary."""
    rule = "=" * 70
    print(rule)
    print("PHASE 8: SECTION 9.2 HYPOTHESIS TESTS")
    print("Why the Demographic Channel Weakens Post-Opening")
    print(rule)

    panel = load_panel()
    years = panel['year']
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries, "
          f"years {years.min()}-{years.max()}")

    # Sample sizes for the transition-economy and opener subsamples.
    transition = panel[panel['is_transition'] == 1]
    opener_rows = transition[transition['status'] == 'opener']
    print(f"Transition: {len(transition)} obs, {transition['iso3'].nunique()} countries")
    print(f"Openers: {len(opener_rows)} obs, {opener_rows['iso3'].nunique()} countries")
    n_pre = (opener_rows['event_time'] < 0).sum()
    n_post = (opener_rows['event_time'] >= 0).sum()
    print(f"Pre-opening: {n_pre} obs")
    print(f"Post-opening: {n_post} obs")

    # Test 1: Structural confounders (tuple elements evaluate left to right)
    confounder_results = (
        test1a_partial_r2(panel),
        test1b_mediation(panel),
        test1c_oster_bounds(panel),
    )

    # Test 2: Countervailing capital flows
    flow_results = (
        test2a_gross_flows(panel),
        test2b_ca_components(panel),
    )

    # Test 3: Composition effects
    cohort_results = test3a_cohort_attenuation(panel)
    fade_results = test3b_event_time_interaction(panel)
    # 3c only writes its own table; its return value is not part of the summary.
    test3c_trajectory_divergence(panel)

    # Summary
    write_summary(*confounder_results, *flow_results, cohort_results, fade_results)

    print("\n" + rule)
    print("PHASE 8 COMPLETE — All tables saved to:")
    print(f"  {OUTPUT_DIR}")
    print(rule)


# Run the full Phase 8 pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
